include("Filter_fast.jl")
include("StarW.jl")
include("h_score3.jl")
include("mx_func.jl")
include("INC3.jl")
include("Filter_fast.jl")
include("CLidx.jl")
include("Clique_sm.jl")

using SparseArrays
using LinearAlgebra
using Clustering
using NearestNeighbors
using Distances
using Laplacians
using Arpack
using Statistics
using DelimitedFiles
using StatsBase
using Laplacians#master
using Random

"""
    decomposition(A, NF, alpha, target_nodes)

Repeatedly coarsen the graph whose adjacency matrix is `A` until at most
`target_nodes` nodes remain. Return `Clique_sm` applied to the final (coarse)
hyperedge list.

Each coarsening level:
1. builds a matrix from the current hyperedges and weights with `StarW`
   ("star expansion");
2. smooths a random vector with `Filter_fast` and orthogonalises the smoothed
   vectors with a QR factorisation;
3. scores every hyperedge by averaging the normalised `h_score3` ratios (the
   source treats this as an approximate effective resistance). It then adds
   `alpha` times the Euclidean distance between the feature columns of the
   hyperedge's first two nodes, and the accumulated super-node weights `Neff`;
4. boosts the weight of hyperedges whose score is below `1e3` and greedily
   contracts the heaviest hyperedges whose nodes are still free;
5. relabels nodes, drops duplicate and single-node hyperedges, and keeps one
   representative feature column of `NF` per coarse node.

# Arguments
- `A`: adjacency matrix; only the strict upper triangle is read, via `findnz`
  (so a sparse matrix is expected).
- `NF`: node-feature matrix with one column per node.
- `alpha`: weight of the feature-distance term.
- `target_nodes`: coarsening stops once the node count is at most this value.

The per-level relabelling vectors are passed to `CLidx` before returning.
Its result is discarded, so it is presumably called for side effects
(TODO confirm).
"""
function decomposition(A, NF, alpha, target_nodes)
    # Smoothing parameters; identical on every level, so set them once.
    initial = 0
    SmS = 100
    interval = 1
    Nrv = 1
    Nsm = Int((SmS - initial) / interval)
    Ntot = Nrv * Nsm

    # Initial edge list: one sorted [u, v] pair per strict-upper-triangular nonzero.
    rows, cols, vals = findnz(triu(A, 1))
    ar = Any[sort([rows[ii], cols[ii]]) for ii in eachindex(rows)]
    W = vals .* 1.0

    # Node-label vectors of every level, handed to CLidx at the end.
    idx_mat = Any[]

    mx = mx_func(ar)
    Neff = zeros(Float64, mx)

    while true  # Keep coarsening until the target node count is reached
        # Number of distinct nodes that still appear in some hyperedge.
        current_nodes = length(unique(Iterators.flatten(ar)))
        if current_nodes <= target_nodes
            println("Target node number reached or exceeded. Stopping coarsening.")
            break
        end

        mx = mx_func(ar)

        ## star expansion (a new name, so the argument `A` is not shadowed)
        Astar = StarW(ar, W)

        ## Smoothed vectors. A local fixed-seed RNG keeps runs reproducible
        ## without resetting the caller's global RNG stream.
        rng = MersenneTwister(1)
        SV = zeros(Float64, mx, Ntot)
        for ii in 1:Nrv
            rv = (rand(rng, Float64, size(Astar, 1), 1) .- 0.5) .* 2
            SV[:, (ii - 1) * Nsm + 1 : ii * Nsm] =
                Filter_fast(rv, SmS, Astar, mx, initial, interval, Nsm)
        end

        ## Make all the smoothed vectors orthogonal to each other
        SV = Matrix(qr(SV).Q)

        ## Normalised per-hyperedge ratios for every smoothed vector
        Eratio = zeros(Float64, length(ar), Ntot)
        for jj in axes(SV, 2)
            hscore = h_score3(ar, SV[:, jj])
            Eratio[:, jj] = hscore ./ sum(hscore)
        end

        ## Average the ratios (kept as a plain vector, not an n×1 matrix)
        Evec = vec(sum(Eratio, dims = 2)) ./ size(SV, 2)

        ## Feature distance between the first two nodes of each hyperedge (eq1)
        NFd = [euclidean(view(NF, :, e[1]), view(NF, :, e[2])) for e in ar]
        Evec = Evec .+ alpha .* NFd

        ## Add the accumulated weights of super nodes from previous levels
        for kk in eachindex(ar)
            Evec[kk] += sum(Neff[ar[kk]])
        end

        ## Normalise the scores and boost the weights of low-score hyperedges
        P = Evec ./ maximum(Evec)
        Nsample = count(x -> x < 1e3, Evec)
        low = sortperm(P)[1:Nsample]
        W[low] = W[low] .* (1 .+ 1 ./ P[low])

        ## Heaviest hyperedges are contracted first
        Pos = sortperm(W, rev = true)

        ## Greedy hyperedge contraction
        flag = falses(mx)          # nodes already merged into a super node
        idx = zeros(Int, mx)       # coarse label of each node (0 = not yet assigned)
        val = 1                    # next coarse label
        Neff_new = Float64[]       # weight of each coarse node, in label order
        selF = Int[]               # representative fine node of each coarse node
        remaining = current_nodes  # node count after the contractions so far

        for ii in 1:Nsample
            e = Pos[ii]
            free = ar[e][.!flag[ar[e]]]
            length(free) > 1 || continue

            push!(selF, free[1])
            idx[free] .= val
            flag[free] .= true
            val += 1

            # Record the super-node weight before any early exit, so Neff_new
            # stays aligned with the coarse labels.
            # NOTE(review): Evec[e] already includes sum(Neff) over the whole
            # hyperedge, so the free nodes' Neff is counted twice here; kept
            # as in the original, confirm this is intended.
            push!(Neff_new, Evec[e] + sum(Neff[free]))

            # Merging k free nodes into one removes k - 1 nodes.
            remaining -= length(free) - 1
            if remaining <= target_nodes
                println("Node count is below the target, stopping contraction.")
                break
            end
        end

        # Without any contraction the graph would never change: avoid looping forever.
        if val == 1
            @warn "No hyperedge could be contracted; stopping coarsening."
            break
        end

        ## Give every untouched node its own coarse label, after the super nodes
        fdz = findall(iszero, idx)
        idx[fdz] = val:(val + length(fdz) - 1)
        append!(selF, fdz)
        append!(Neff_new, Neff[fdz])

        push!(idx_mat, idx)

        ## Relabel the hyperedges into the coarse graph
        ar_new = Any[sort(unique(idx[nd])) for nd in ar]

        ## Remove repeated hyperedges, keeping the weight of the first occurrence
        keep = unique(z -> ar_new[z], eachindex(ar_new))
        ar_new = ar_new[keep]
        W2 = W[keep]

        ## Remove hyperedges with cardinality 1 (row sums of the INC3 result)
        HH = INC3(ar_new)
        single = findall(x -> x == 1, sum(HH, dims = 2)[:, 1])
        deleteat!(ar_new, single)
        deleteat!(W2, single)

        ar = ar_new
        Neff = Neff_new
        W = W2
        NF = NF[:, selF]  # one feature column per coarse node, in label order
    end  # end of while loop

    CLidx(idx_mat)

    return Clique_sm(ar)
end
